import numpy as np
import logging
import pathlib
import xml.etree.ElementTree as ET
import cv2
import os


class SKUDataset:
    """Detection dataset for SKU product images stored in a VOC-style layout.

    Expected files under ``root``:
        ImageSets/Main/{train,val,test}.txt   one image id per line
        Annotations/<image_id>.xml            object annotations
        JPEGImages/<image_id>.jpg             images
    """

    # Image-set file stem used for each supported ``mode`` value.
    _MODE_FILES = {0: "train", 1: "val", 2: "test"}

    # Class names used when no labels file is found; index 0 is the background class.
    _DEFAULT_CLASS_NAMES = ('BACKGROUND',
                            'Ajiniuruwei', 'anganzhengzhuanhuashengren', 'aoliaojiaxinbinggan',
                            'asamunaicha', 'baiyuheidoujiang', 'beibingyangjuqi',
                            'doufenyizusuanlafen', 'haoliyoulangligelang', 'hunheguorenangan',
                            'huorunmangguosuannai', 'huoxingrusuanyuanwei', 'jiadelechengwei',
                            'jiadunyuanweisuda', 'jianjiaolvping', 'jiayuanguobaniurouwei',
                            'junzaijuanmianjinxianglawei170g', 'kaqiguobamalalongxiawei',
                            'kekouxianwei', 'lekasiweihuajuan', 'maidongqingningkouwei',
                            'meizhiyuanguolicheng', 'moqitaozhiyinliao', 'nongfushanquanshui',
                            'quecaonatie', 'rusuandaogao', 'sandelinuancheng', 'tongyilaotansuancaimian',
                            'xiangpiaopiaoniurucha', 'yiquanningmengqishui', 'zuocansuannaiqingningwei')

    def __init__(self, root, transform=None, target_transform=None, mode=0, keep_difficult=False,
                 label_file=None):
        """Dataset for SKU data.

        Args:
            root: the root of the 30 SKU dataset; the directory contains the sub-directories
                Annotations, ImageSets, JPEGImages, SegmentationClass, SegmentationObject.
            transform: optional callable invoked as ``transform(image, boxes, labels)`` and
                expected to return ``(image, boxes, labels)``. When it is not set,
                ``__getitem__`` returns only the image.
            target_transform: optional callable invoked as ``target_transform(boxes, labels)``
                after ``transform``. It is only used when ``transform`` is set.
            mode: 0-train 1-val 2-test. default=0.
            keep_difficult: if False, objects flagged as difficult are dropped in ``__getitem__``.
            label_file: optional path to a comma separated class-names file. If it is None
                or does not exist, ``root/labels.txt`` is used. If that is missing too, the
                built-in default SKU classes are used.

        Raises:
            ValueError: if ``mode`` is not 0, 1 or 2.
        """
        self.root = pathlib.Path(root)
        self.transform = transform
        self.target_transform = target_transform
        split = self._MODE_FILES.get(int(mode))
        if split is None:
            # Raise instead of exiting the interpreter so callers can handle it.
            raise ValueError(f"Dataset mode {mode} is not supported.")
        image_sets_file = self.root / "ImageSets" / "Main" / f"{split}.txt"
        self.ids = SKUDataset._read_image_ids(image_sets_file)
        self.keep_difficult = keep_difficult

        self.class_names = self._load_class_names(label_file)
        self.class_dict = {class_name: i for i, class_name in enumerate(self.class_names)}

    def _load_class_names(self, label_file):
        """Return the tuple of class names, with 'BACKGROUND' always at index 0."""
        label_path = self.root / "labels.txt"
        if label_file is not None:
            if os.path.isfile(label_file):
                label_path = pathlib.Path(label_file)
            else:
                logging.warning("Label file %s not found, falling back to %s", label_file, label_path)

        if not label_path.is_file():
            logging.info("No labels file, using default SKU classes.")
            return self._DEFAULT_CLASS_NAMES

        classes = ['BACKGROUND']
        with open(label_path, 'r') as infile:
            for line in infile:
                # Names are comma separated, and a line break also ends a name.
                # Empty entries are skipped. They come from the trailing comma
                # written after every name. Whitespace inside a name is removed.
                for raw_name in line.split(','):
                    name = "".join(raw_name.split())
                    if name:
                        classes.append(name)
        return tuple(classes)

    def __getitem__(self, index):
        """Return ``(image, boxes, labels)`` when ``transform`` is set, else just the RGB image.

        Annotations are only parsed when ``transform`` is set. ``target_transform``
        is applied after ``transform`` and is ignored when ``transform`` is not set.
        """
        image_id = self.ids[index]
        if not self.transform:
            return self._read_image(image_id)
        boxes, labels, is_difficult = self._get_annotation(image_id)
        if not self.keep_difficult:
            keep = is_difficult == 0
            boxes = boxes[keep]
            labels = labels[keep]
        image = self._read_image(image_id)
        image, boxes, labels = self.transform(image, boxes, labels)
        if self.target_transform:
            boxes, labels = self.target_transform(boxes, labels)
        return image, boxes, labels

    def get_image(self, index):
        """Return the RGB image at ``index``, optionally passed through ``transform``.

        ``transform`` is called here with the image alone, and its result is
        unpacked into two values; only the first is returned.
        """
        image_id = self.ids[index]
        image = self._read_image(image_id)
        if self.transform:
            image, _ = self.transform(image)
        return image

    def get_annotation(self, index):
        """Return ``(image_id, (boxes, labels, is_difficult))`` for the sample at ``index``."""
        image_id = self.ids[index]
        return image_id, self._get_annotation(image_id)

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _read_image_ids(image_sets_file):
        """Read image ids, one per line, ignoring blank lines."""
        with open(image_sets_file) as f:
            return [line.rstrip() for line in f if line.strip()]

    def _get_annotation(self, image_id):
        """Parse ``Annotations/<image_id>.xml``.

        Returns:
            A tuple ``(boxes, labels, is_difficult)``:
            * ``boxes`` is a float32 array of shape (N, 4) holding ``[x1, y1, x2, y2]``.
            * ``labels`` is an int64 array of shape (N,) holding indices into ``class_names``.
            * ``is_difficult`` is a uint8 array of shape (N,).

            Objects whose class is not in ``class_dict`` are skipped with a warning.
        """
        annotation_file = self.root / f"Annotations/{image_id}.xml"
        boxes = []
        labels = []
        is_difficult = []
        for obj in ET.parse(annotation_file).findall("object"):
            class_name = obj.findtext('name', default='').strip()
            # Only classes in our list are kept.
            if class_name not in self.class_dict:
                logging.warning("%s get annotation %s not in class_dict", image_id, class_name)
                continue
            bbox = obj.find('bndbox')
            # Coordinates are treated as 1-based (Matlab convention) and shifted to 0-based.
            boxes.append([float(bbox.find(tag).text) - 1 for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
            labels.append(self.class_dict[class_name])
            # A missing or empty <difficult> element means "not difficult".
            difficult = (obj.findtext('difficult') or '').strip()
            is_difficult.append(int(difficult) if difficult else 0)

        # reshape keeps boxes 2-D, i.e. shape (0, 4), when the image has no usable objects.
        return (np.array(boxes, dtype=np.float32).reshape(-1, 4),
                np.array(labels, dtype=np.int64),
                np.array(is_difficult, dtype=np.uint8))

    def _read_image(self, image_id):
        """Load ``JPEGImages/<image_id>.jpg`` and return it in RGB channel order."""
        image_file = self.root / f"JPEGImages/{image_id}.jpg"
        image = cv2.imread(str(image_file))
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError(f"Could not read image file: {image_file}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)



def write_labels_file(source_dir, labels_file):
    """Write one ``name,`` line per entry of ``source_dir`` into ``labels_file``.

    The output uses the comma separated format that ``SKUDataset`` reads. Entries
    are sorted because ``os.listdir`` order is arbitrary, and the line order
    determines each class's index.
    """
    names = sorted(os.listdir(source_dir))
    with open(labels_file, 'w') as p:
        p.write("".join(f"{name},\n" for name in names))


if __name__ == '__main__':
    write_labels_file(r'E:\Project\product_data\train', r'E:\Project\product_data\labels.txt')

